{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='1'></a>\n",
    "# 1. Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Model\n",
    "from keras.layers import *\n",
    "from keras.layers.advanced_activations import LeakyReLU\n",
    "from keras.initializers import RandomNormal\n",
    "from pixel_shuffler import PixelShuffler\n",
    "import keras.backend as K\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mtcnn_detect_face\n",
    "from umeyama import umeyama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='2'></a>\n",
    "# 2. Install requirements\n",
    "\n",
    "## ========== CAUTION ========== \n",
    "\n",
    "If you are running this notebook on a local machine, please read [this blog](http://jakevdp.github.io/blog/2017/12/05/installing-python-packages-from-jupyter/) before running the following cells, which pip-install packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://github.com/ageitgey/face_recognition\n",
    "#!pip install face_recognition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install moviepy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='4'></a>\n",
    "# 4. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the Keras learning phase to 0 (test mode) before any model below is built.\n",
    "K.set_learning_phase(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the Input layer shared by the two generators (netGA / netGB).\n",
    "IMAGE_SHAPE = (64, 64, 3)\n",
    "nc_in = 3 # number of input channels of generators\n",
    "nc_D_inp = 6 # number of input channels of discriminators\n",
    "\n",
    "# When True, Discriminator / Encoder / Decoder_ps insert self_attn_block layers.\n",
    "use_self_attn = False\n",
    "w_l2 = 1e-4 # weight decay (L2 coefficient passed to kernel_regularizer)\n",
    "\n",
    "batchSize = 8\n",
    "\n",
    "# Path of training images\n",
    "img_dirA = './faceA/*.*'\n",
    "img_dirB = './faceB/*.*'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='5'></a>\n",
    "# 5. Define models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kernel initializer shared by the Conv2D layers that pass kernel_initializer=conv_init.\n",
    "conv_init = RandomNormal(0, 0.02)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class Scale(Layer):\n",
    "    '''\n",
    "    Keras layer that multiplies its input by one learnable scalar, `gamma`.\n",
    "\n",
    "    Code borrows from https://github.com/flyyufelix/cnn_finetune\n",
    "\n",
    "    # Arguments\n",
    "        weights: optional list of initial weights, applied with `set_weights` during `build`.\n",
    "        axis: stored on the layer and reported by `get_config`; `call` does not use it.\n",
    "        gamma_init: initializer identifier for `gamma` (defaults to 'zero').\n",
    "    '''\n",
    "    def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):\n",
    "        self.initial_weights = weights\n",
    "        self.axis = axis\n",
    "        self.gamma_init = initializers.get(gamma_init)\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        '''Create the scalar `gamma` variable and apply any initial weights.'''\n",
    "        self.input_spec = [InputSpec(shape=input_shape)]\n",
    "\n",
    "        # Compatibility with TensorFlow >= 1.0.0\n",
    "        gamma_name = '{}_gamma'.format(self.name)\n",
    "        self.gamma = K.variable(self.gamma_init((1,)), name=gamma_name)\n",
    "        self.trainable_weights = [self.gamma]\n",
    "\n",
    "        preset = self.initial_weights\n",
    "        if preset is not None:\n",
    "            self.set_weights(preset)\n",
    "            # The initial weights are only needed once.\n",
    "            del self.initial_weights\n",
    "\n",
    "    def call(self, x, mask=None):\n",
    "        '''Return `gamma * x`; `mask` is accepted but ignored.'''\n",
    "        return self.gamma * x\n",
    "\n",
    "    def get_config(self):\n",
    "        '''Return the base layer config extended with `axis`.'''\n",
    "        merged = dict(super().get_config())\n",
    "        merged.update({\"axis\": self.axis})\n",
    "        return merged\n",
    "\n",
    "\n",
    "def self_attn_block(inp, nc):\n",
    "    '''\n",
    "    Self-attention block with a residual connection.\n",
    "\n",
    "    Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow\n",
    "\n",
    "    Three 1x1 convolutions project `inp` to f and g (nc//8 channels) and h (nc channels).\n",
    "    The softmax of g . f^T weights h, the result is reshaped back to the input shape,\n",
    "    scaled by a `Scale` layer and added to `inp`.\n",
    "    '''\n",
    "    assert nc//8 > 0, f\"Input channels must be >= 8, but got nc={nc}\"\n",
    "    inp_shape = inp.get_shape().as_list()\n",
    "\n",
    "    # 1x1 projections (created in f, g, h order so layer numbering stays stable).\n",
    "    proj_f = Conv2D(nc//8, 1, kernel_initializer=conv_init)(inp)\n",
    "    proj_g = Conv2D(nc//8, 1, kernel_initializer=conv_init)(inp)\n",
    "    proj_h = Conv2D(nc, 1, kernel_initializer=conv_init)(inp)\n",
    "\n",
    "    # Collapse the spatial axes so each position becomes one row.\n",
    "    flat_f = Reshape((-1, proj_f.get_shape().as_list()[-1]))(proj_f)\n",
    "    flat_g = Reshape((-1, proj_g.get_shape().as_list()[-1]))(proj_g)\n",
    "    flat_h = Reshape((-1, proj_h.get_shape().as_list()[-1]))(proj_h)\n",
    "\n",
    "    scores = Lambda(lambda t: tf.matmul(t[0], t[1], transpose_b=True))([flat_g, flat_f])\n",
    "    weights = Softmax(axis=-1)(scores)\n",
    "\n",
    "    attended = Lambda(lambda t: tf.matmul(t[0], t[1]))([weights, flat_h])\n",
    "    attended = Reshape(inp_shape[1:])(attended)\n",
    "    # Scale() uses its default 'zero' gamma initializer.\n",
    "    attended = Scale()(attended)\n",
    "\n",
    "    return add([attended, inp])\n",
    "\n",
    "def conv_block(input_tensor, f):\n",
    "    '''Stride-2 3x3 convolution (no bias, L2-regularized) with `f` filters, followed by ReLU.'''\n",
    "    downsampled = Conv2D(f, kernel_size=3, strides=2,\n",
    "                         kernel_regularizer=regularizers.l2(w_l2),\n",
    "                         kernel_initializer=conv_init,\n",
    "                         use_bias=False, padding=\"same\")(input_tensor)\n",
    "    return Activation(\"relu\")(downsampled)\n",
    "\n",
    "def conv_block_d(input_tensor, f, use_instance_norm=False):\n",
    "    '''\n",
    "    Stride-2 4x4 convolution (no bias, L2-regularized) with `f` filters, followed by LeakyReLU(0.2).\n",
    "\n",
    "    `use_instance_norm` is accepted but not used by this implementation.\n",
    "    '''\n",
    "    downsampled = Conv2D(f, kernel_size=4, strides=2,\n",
    "                         kernel_regularizer=regularizers.l2(w_l2),\n",
    "                         kernel_initializer=conv_init,\n",
    "                         use_bias=False, padding=\"same\")(input_tensor)\n",
    "    return LeakyReLU(alpha=0.2)(downsampled)\n",
    "\n",
    "def res_block(input_tensor, f):\n",
    "    '''\n",
    "    Residual block: conv3x3 -> LeakyReLU -> conv3x3, added to the input, then LeakyReLU.\n",
    "\n",
    "    Both convolutions use `f` filters, so `input_tensor` needs `f` channels for the add.\n",
    "    '''\n",
    "    def conv3x3(tensor):\n",
    "        # Fresh layer per call: 3x3, same padding, no bias, L2-regularized.\n",
    "        return Conv2D(f, kernel_size=3, kernel_regularizer=regularizers.l2(w_l2),\n",
    "                      kernel_initializer=conv_init, use_bias=False, padding=\"same\")(tensor)\n",
    "\n",
    "    branch = conv3x3(input_tensor)\n",
    "    branch = LeakyReLU(alpha=0.2)(branch)\n",
    "    branch = conv3x3(branch)\n",
    "    merged = add([branch, input_tensor])\n",
    "    return LeakyReLU(alpha=0.2)(merged)\n",
    "\n",
    "def upscale_ps(filters, use_norm=True):\n",
    "    '''\n",
    "    Build an upscaling block: conv3x3 with `filters*4` channels -> LeakyReLU(0.2) -> PixelShuffler.\n",
    "\n",
    "    Returns a callable that applies the block to a tensor. `use_norm` is accepted but unused.\n",
    "    '''\n",
    "    def block(x):\n",
    "        expanded = Conv2D(filters*4, kernel_size=3,\n",
    "                          kernel_regularizer=regularizers.l2(w_l2),\n",
    "                          kernel_initializer=RandomNormal(0, 0.02),\n",
    "                          padding='same')(x)\n",
    "        activated = LeakyReLU(0.2)(expanded)\n",
    "        return PixelShuffler()(activated)\n",
    "    return block\n",
    "\n",
    "def Discriminator(nc_in, input_size=64):\n",
    "    '''\n",
    "    Build the discriminator: three stride-2 conv blocks (64, 128, 256 filters) and a 1-channel conv head.\n",
    "\n",
    "    Self-attention is inserted after the 128- and 256-filter blocks when `use_self_attn` is set.\n",
    "    '''\n",
    "    inp = Input(shape=(input_size, input_size, nc_in))\n",
    "    x = inp\n",
    "    for filters, attend in ((64, False), (128, True), (256, True)):\n",
    "        x = conv_block_d(x, filters, False)\n",
    "        if attend and use_self_attn:\n",
    "            x = self_attn_block(x, filters)\n",
    "    out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init,\n",
    "                 use_bias=False, padding=\"same\")(x)\n",
    "    return Model(inputs=[inp], outputs=out)\n",
    "\n",
    "def Encoder(nc_in=3, input_size=64):\n",
    "    '''\n",
    "    Build the shared encoder.\n",
    "\n",
    "    A 5x5 convolution is followed by four stride-2 conv blocks (128 to 1024 filters), a dense\n",
    "    bottleneck (1024 -> 4*4*1024), a reshape to (4, 4, 1024) and one pixel-shuffle upscale to 512 filters.\n",
    "    Self-attention follows the 256- and 512-filter blocks when `use_self_attn` is set.\n",
    "    '''\n",
    "    inp = Input(shape=(input_size, input_size, nc_in))\n",
    "    x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init,\n",
    "               use_bias=False, padding=\"same\")(inp)\n",
    "    for filters, attend in ((128, False), (256, True), (512, True), (1024, False)):\n",
    "        x = conv_block(x, filters)\n",
    "        if attend and use_self_attn:\n",
    "            x = self_attn_block(x, filters)\n",
    "    flat = Flatten()(x)\n",
    "    x = Dense(1024)(flat)\n",
    "    x = Dense(4*4*1024)(x)\n",
    "    x = Reshape((4, 4, 1024))(x)\n",
    "    out = upscale_ps(512)(x)\n",
    "    return Model(inputs=inp, outputs=out)\n",
    "\n",
    "def Decoder_ps(nc_in=512, input_size=8):\n",
    "    '''\n",
    "    Build a decoder: three pixel-shuffle upscales (256, 128, 64 filters) and one residual block.\n",
    "\n",
    "    The output concatenates a 1-channel sigmoid alpha map with a 3-channel tanh image (alpha first).\n",
    "    Self-attention is inserted after the 128-filter upscale and after the residual block\n",
    "    when `use_self_attn` is set.\n",
    "    '''\n",
    "    input_ = Input(shape=(input_size, input_size, nc_in))\n",
    "    x = upscale_ps(256)(input_)\n",
    "    x = upscale_ps(128)(x)\n",
    "    if use_self_attn:\n",
    "        x = self_attn_block(x, 128)\n",
    "    x = upscale_ps(64)(x)\n",
    "    x = res_block(x, 64)\n",
    "    if use_self_attn:\n",
    "        x = self_attn_block(x, 64)\n",
    "    # Two separate heads over the same features.\n",
    "    alpha = Conv2D(1, kernel_size=5, padding='same', activation=\"sigmoid\")(x)\n",
    "    rgb = Conv2D(3, kernel_size=5, padding='same', activation=\"tanh\")(x)\n",
    "    out = concatenate([alpha, rgb])\n",
    "    return Model(input_, out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One encoder instance, plus one decoder per identity (A and B).\n",
    "encoder = Encoder()\n",
    "decoder_A = Decoder_ps()\n",
    "decoder_B = Decoder_ps()\n",
    "\n",
    "x = Input(shape=IMAGE_SHAPE)\n",
    "\n",
    "# Both generators call the same `encoder`; they differ only in the decoder.\n",
    "netGA = Model(x, decoder_A(encoder(x)))\n",
    "netGB = Model(x, decoder_B(encoder(x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='6'></a>\n",
    "# 6. Load Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best-effort load of the trained weights; a failure is reported but does not stop the notebook.\n",
    "try:\n",
    "    encoder.load_weights(\"models/encoder.h5\")\n",
    "    decoder_A.load_weights(\"models/decoder_A.h5\")\n",
    "    decoder_B.load_weights(\"models/decoder_B.h5\")\n",
    "    print (\"Model weights files are successfully loaded\")\n",
    "except Exception as err:\n",
    "    # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still propagates,\n",
    "    # and the cause is shown instead of being silently discarded.\n",
    "    print (\"!! Error occurs during loading weights files.\")\n",
    "    print (repr(err))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='7'></a>\n",
    "# 7. Define Inputs/Outputs Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cycle_variables(netG):\n",
    "    '''\n",
    "    Build backend functions around a generator's first input and first output.\n",
    "\n",
    "    The output is split into alpha (channel 0) and rgb (remaining channels).\n",
    "\n",
    "    Returns (fn_generate, fn_mask, fn_abgr):\n",
    "        fn_generate: alpha * rgb + (1 - alpha) * input\n",
    "        fn_mask:     alpha repeated three times along the last axis\n",
    "        fn_abgr:     alpha concatenated with rgb\n",
    "    '''\n",
    "    net_input = netG.inputs[0]\n",
    "    net_output = netG.outputs[0]\n",
    "    alpha = Lambda(lambda t: t[:,:,:, :1])(net_output)\n",
    "    rgb = Lambda(lambda t: t[:,:,:, 1:])(net_output)\n",
    "\n",
    "    blended = alpha * rgb + (1-alpha) * net_input\n",
    "    fn_generate = K.function([net_input], [blended])\n",
    "\n",
    "    alpha_x3 = concatenate([alpha, alpha, alpha])\n",
    "    fn_mask = K.function([net_input], [alpha_x3])\n",
    "\n",
    "    alpha_rgb = concatenate([alpha, rgb])\n",
    "    fn_abgr = K.function([net_input], [alpha_rgb])\n",
    "    return fn_generate, fn_mask, fn_abgr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per generator: masked-output function, 3-channel alpha-mask function, alpha+rgb function.\n",
    "path_A, path_mask_A, path_abgr_A = cycle_variables(netGA)\n",
    "path_B, path_mask_B, path_abgr_B = cycle_variables(netGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='12'></a>\n",
    "# 12. Import modules for video making\n",
    "\n",
    "Given a video as input, the following cells detect the face in each frame using MTCNN, use the trained GAN model to transform each detected face into the target face, and then output a video with the swapped faces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Download ffmpeg if need, which is required by moviepy.\n",
    "\n",
    "#import imageio\n",
    "#imageio.plugins.ffmpeg.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from moviepy.editor import VideoFileClip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define transform path: A2B or B2A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "direction = \"BtoA\" # default: transform faceB to faceA\n",
    "\n",
    "# Compare by value (`==`); `is` tests object identity and only works for\n",
    "# string literals by accident of interning.\n",
    "if direction == \"AtoB\":\n",
    "    path_func = path_abgr_B\n",
    "elif direction == \"BtoA\":\n",
    "    path_func = path_abgr_A\n",
    "else:\n",
    "    # Fail here rather than later with a NameError on the undefined `path_func`.\n",
    "    raise ValueError(\"direction should be either AtoB or BtoA\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MTCNN setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_mtcnn(sess, model_path):\n",
    "    '''\n",
    "    Build the three MTCNN networks and load their weights into `sess`.\n",
    "\n",
    "    # Arguments\n",
    "        sess: TensorFlow session passed to each network's `load`.\n",
    "        model_path: directory containing det1.npy, det2.npy and det3.npy. When falsy,\n",
    "            the directory of this file is used, or the current working directory\n",
    "            when `__file__` is not defined (as in a notebook kernel).\n",
    "\n",
    "    # Returns\n",
    "        (pnet, rnet, onet) network objects created under the variable scopes\n",
    "        'pnet2', 'rnet2' and 'onet2'.\n",
    "    '''\n",
    "    if not model_path:\n",
    "        try:\n",
    "            model_path,_ = os.path.split(os.path.realpath(__file__))\n",
    "        except NameError:\n",
    "            # `__file__` does not exist inside a notebook kernel.\n",
    "            model_path = os.getcwd()\n",
    "\n",
    "    with tf.variable_scope('pnet2'):\n",
    "        data = tf.placeholder(tf.float32, (None,None,None,3), 'input')\n",
    "        pnet = mtcnn_detect_face.PNet({'data':data})\n",
    "        pnet.load(os.path.join(model_path, 'det1.npy'), sess)\n",
    "    with tf.variable_scope('rnet2'):\n",
    "        data = tf.placeholder(tf.float32, (None,24,24,3), 'input')\n",
    "        rnet = mtcnn_detect_face.RNet({'data':data})\n",
    "        rnet.load(os.path.join(model_path, 'det2.npy'), sess)\n",
    "    with tf.variable_scope('onet2'):\n",
    "        data = tf.placeholder(tf.float32, (None,48,48,3), 'input')\n",
    "        onet = mtcnn_detect_face.ONet({'data':data})\n",
    "        onet.load(os.path.join(model_path, 'det3.npy'), sess)\n",
    "    return pnet, rnet, onet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the MTCNN weight files (det1.npy, det2.npy, det3.npy).\n",
    "WEIGHTS_PATH = \"./mtcnn_weights/\"\n",
    "\n",
    "# Reuse the Keras session so the MTCNN graphs live alongside the generator models.\n",
    "sess = K.get_session()\n",
    "with sess.as_default():\n",
    "    # The former `global pnet, rnet, onet` here was a no-op at module level and named\n",
    "    # variables this cell never assigns, so it has been dropped.\n",
    "    pnet2, rnet2, onet2 = create_mtcnn(sess, WEIGHTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: at notebook top level a `global` statement has no effect; pnet/rnet/onet are bound in a later cell.\n",
    "global pnet, rnet, onet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each MTCNN network as a backend function: 'data' input -> regression/probability outputs.\n",
    "pnet_fun = K.function(\n",
    "    [pnet2.layers['data']],\n",
    "    [pnet2.layers['conv4-2'], pnet2.layers['prob1']],\n",
    ")\n",
    "rnet_fun = K.function(\n",
    "    [rnet2.layers['data']],\n",
    "    [rnet2.layers['conv5-2'], rnet2.layers['prob1']],\n",
    ")\n",
    "onet_fun = K.function(\n",
    "    [onet2.layers['data']],\n",
    "    [onet2.layers['conv6-2'], onet2.layers['conv6-3'], onet2.layers['prob1']],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the three networks inside their existing variable scopes (reuse=True) and reload the weights.\n",
    "# Uses WEIGHTS_PATH instead of repeating the literal path.\n",
    "with tf.variable_scope('pnet2', reuse=True):\n",
    "    data = tf.placeholder(tf.float32, (None,None,None,3), 'input')\n",
    "    pnet2 = mtcnn_detect_face.PNet({'data':data})\n",
    "    pnet2.load(os.path.join(WEIGHTS_PATH, 'det1.npy'), sess)\n",
    "with tf.variable_scope('rnet2', reuse=True):\n",
    "    data = tf.placeholder(tf.float32, (None,24,24,3), 'input')\n",
    "    rnet2 = mtcnn_detect_face.RNet({'data':data})\n",
    "    rnet2.load(os.path.join(WEIGHTS_PATH, 'det2.npy'), sess)\n",
    "with tf.variable_scope('onet2', reuse=True):\n",
    "    data = tf.placeholder(tf.float32, (None,48,48,3), 'input')\n",
    "    onet2 = mtcnn_detect_face.ONet({'data':data})\n",
    "    onet2.load(os.path.join(WEIGHTS_PATH, 'det3.npy'), sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backend functions handed to mtcnn_detect_face.detect_face as pnet / rnet / onet.\n",
    "pnet = K.function(\n",
    "    [pnet2.layers['data']],\n",
    "    [pnet2.layers['conv4-2'], pnet2.layers['prob1']],\n",
    ")\n",
    "rnet = K.function(\n",
    "    [rnet2.layers['data']],\n",
    "    [rnet2.layers['conv5-2'], rnet2.layers['prob1']],\n",
    ")\n",
    "onet = K.function(\n",
    "    [onet2.layers['data']],\n",
    "    [onet2.layers['conv6-2'], onet2.layers['conv6-3'], onet2.layers['prob1']],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='13'></a>\n",
    "# 13. Make video clips\n",
    "\n",
    "### Default transform: face B to face A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Module-level switches. `use_smoothed_bbox` gates the bounding-box smoothing in process_video;\n",
    "# `use_smoothed_mask` is presumably read further down in this cell (verify).\n",
    "use_smoothed_mask = True\n",
    "use_smoothed_bbox = True\n",
    "\n",
    "def is_overlap(box1, box2):\n",
    "    '''\n",
    "    Return True when the intersection of the two boxes covers at least 20% of `box1`.\n",
    "\n",
    "    Boxes are indexed as (x0, y1, x1, y0): indices 0 and 2 are the low/high x edges,\n",
    "    indices 3 and 1 the low/high y edges. Any extra trailing entries are ignored.\n",
    "    '''\n",
    "    inter_w = float(min(box1[2], box2[2])) - float(max(box1[0], box2[0]))\n",
    "    inter_h = float(min(box1[1], box2[1])) - float(max(box1[3], box2[3]))\n",
    "    # Disjoint boxes have a non-positive extent on at least one axis. Without this guard,\n",
    "    # two negative extents multiply into a positive \"area\" and are reported as overlapping.\n",
    "    if inter_w <= 0 or inter_h <= 0:\n",
    "        return False\n",
    "    area_box1 = (box1[2]-box1[0]) * (box1[1]-box1[3])\n",
    "    # A degenerate reference box cannot be covered (and would divide by zero).\n",
    "    if area_box1 <= 0:\n",
    "        return False\n",
    "    return bool((inter_w * inter_h) / area_box1 >= 0.2)\n",
    "    \n",
    "def remove_overlaps(faces):\n",
    "    '''\n",
    "    Keep the most confident face plus every face that does not overlap it.\n",
    "\n",
    "    The most confident face comes first in the returned list. Each entry is an\n",
    "    (x0, y1, x1, y0, conf_score) tuple.\n",
    "    '''\n",
    "    anchor = get_most_conf_face(faces)[0]\n",
    "    kept = [anchor]\n",
    "    for x0, y1, x1, y0, conf_score in faces:\n",
    "        if is_overlap(anchor, (x0, y1, x1, y0)):\n",
    "            continue\n",
    "        kept.append((x0, y1, x1, y0, conf_score))\n",
    "    return kept\n",
    "\n",
    "def get_most_conf_face(faces):\n",
    "    '''\n",
    "    Return a one-element list holding the face with the highest confidence score.\n",
    "\n",
    "    On ties the last such face wins. Returns None when `faces` is empty or no score is >= 0.\n",
    "    '''\n",
    "    top_score = 0\n",
    "    winner = None\n",
    "    for face in faces:\n",
    "        x0, y1, x1, y0, conf_score = face\n",
    "        if conf_score >= top_score:\n",
    "            top_score, winner = conf_score, [(x0, y1, x1, y0, conf_score)]\n",
    "    return winner\n",
    "\n",
    "def kalmanfilter_init(noise_coef):\n",
    "    '''\n",
    "    Create a cv2.KalmanFilter with 4 state and 2 measurement dimensions.\n",
    "\n",
    "    The measurement matrix picks the first two state entries, the transition matrix adds\n",
    "    state entries 2/3 to entries 0/1 each step, and the process noise covariance is\n",
    "    `noise_coef` times the identity.\n",
    "    '''\n",
    "    kf = cv2.KalmanFilter(4,2)\n",
    "    identity = np.eye(4, dtype=np.float32)\n",
    "    kf.measurementMatrix = np.eye(2, 4, dtype=np.float32)\n",
    "    kf.transitionMatrix = identity + np.eye(4, k=2, dtype=np.float32)\n",
    "    kf.processNoiseCov = noise_coef * identity\n",
    "    return kf\n",
    "\n",
    "def is_higher_than_480p(x):\n",
    "    '''True when `x.shape[0] * x.shape[1]` is at least 858*480 (the bound is inclusive).'''\n",
    "    pixel_count = x.shape[0] * x.shape[1]\n",
    "    return pixel_count >= (858*480)\n",
    "\n",
    "def is_higher_than_720p(x):\n",
    "    '''True when `x.shape[0] * x.shape[1]` is at least 1280*720 (the bound is inclusive).'''\n",
    "    pixel_count = x.shape[0] * x.shape[1]\n",
    "    return pixel_count >= (1280*720)\n",
    "\n",
    "def is_higher_than_1080p(x):\n",
    "    '''True when `x.shape[0] * x.shape[1]` is at least 1920*1080 (the bound is inclusive).'''\n",
    "    pixel_count = x.shape[0] * x.shape[1]\n",
    "    return pixel_count >= (1920*1080)\n",
    "\n",
    "def calibrate_coord(faces, video_scaling_factor):\n",
    "    '''\n",
    "    Multiply the four box coordinates of every face by `video_scaling_factor`, in place.\n",
    "\n",
    "    The confidence score (fifth entry) is left unchanged. Returns the same `faces` object.\n",
    "    '''\n",
    "    for idx, face in enumerate(faces):\n",
    "        x0, y1, x1, y0, score = face\n",
    "        scaled = tuple(coord*video_scaling_factor for coord in (x0, y1, x1, y0))\n",
    "        faces[idx] = scaled + (score,)\n",
    "    return faces\n",
    "\n",
    "def process_mtcnn_bbox(bboxes, im_shape):\n",
    "    '''\n",
    "    Convert MTCNN boxes to square boxes, in place, and return `bboxes`.\n",
    "\n",
    "    The first four columns are read as (y0, x0, y1, x1) and rewritten as (x0, y1, x1, y0).\n",
    "    The square's side is the mean of the box width and height, centred on the original box;\n",
    "    x is clipped to [0, im_shape[0]] and y to [0, im_shape[1]].\n",
    "    '''\n",
    "    # output bbox coordinate of MTCNN is (y0, x0, y1, x1)\n",
    "    # Process the bbox coord. to a square bbox with ordering (x0, y1, x1, y0)\n",
    "    for i in range(len(bboxes)):\n",
    "        y0, x0, y1, x1 = bboxes[i,0:4]\n",
    "        side = (int(y1 - y0) + int(x1 - x0))/2\n",
    "        half = side//2\n",
    "        center_x = int((x1+x0)/2)\n",
    "        center_y = int((y1+y0)/2)\n",
    "        new_x0 = np.max([0, (center_x-half)])\n",
    "        new_x1 = np.min([im_shape[0], (center_x+half)])\n",
    "        new_y0 = np.max([0, (center_y-half)])\n",
    "        new_y1 = np.min([im_shape[1], (center_y+half)])\n",
    "        bboxes[i,0:4] = new_x0, new_y1, new_x1, new_y0\n",
    "    return bboxes\n",
    "\n",
    "def get_faces_bbox(image):\n",
    "    '''\n",
    "    Detect faces in `image` with MTCNN and return square boxes in full-frame coordinates.\n",
    "\n",
    "    Detection runs on a downscaled copy when `manually_downscale` is set (factor\n",
    "    `manually_downscale_factor`) or when the frame is at least 480p / 720p / 1080p\n",
    "    (factor 2 / 3 / 4 plus `video_scaling_offset`). The boxes are then scaled back up.\n",
    "    Smaller frames are processed at their original size.\n",
    "\n",
    "    Reads the module-level names pnet, rnet, onet, detec_threshold, manually_downscale,\n",
    "    manually_downscale_factor and video_scaling_offset.\n",
    "    '''\n",
    "    minsize = 20 # minimum size of face\n",
    "    threshold = [ 0.6, 0.7, detec_threshold ]  # three steps's threshold\n",
    "    factor = 0.709 # scale factor\n",
    "\n",
    "    # Choose the downscale factor; the manual setting takes priority over the resolution tiers.\n",
    "    downscale = True\n",
    "    if manually_downscale:\n",
    "        video_scaling_factor = manually_downscale_factor\n",
    "    elif is_higher_than_1080p(image):\n",
    "        video_scaling_factor = 4 + video_scaling_offset\n",
    "    elif is_higher_than_720p(image):\n",
    "        video_scaling_factor = 3 + video_scaling_offset\n",
    "    elif is_higher_than_480p(image):\n",
    "        video_scaling_factor = 2 + video_scaling_offset\n",
    "    else:\n",
    "        downscale = False\n",
    "\n",
    "    if not downscale:\n",
    "        faces, pnts = mtcnn_detect_face.detect_face(image, minsize, pnet, rnet, onet, threshold, factor)\n",
    "        return process_mtcnn_bbox(faces, image.shape)\n",
    "\n",
    "    # cv2.resize takes (width, height), i.e. (shape[1], shape[0]).\n",
    "    target_size = (image.shape[1]//video_scaling_factor,\n",
    "                   image.shape[0]//video_scaling_factor)\n",
    "    resized_image = cv2.resize(image, target_size)\n",
    "    faces, pnts = mtcnn_detect_face.detect_face(resized_image, minsize, pnet, rnet, onet, threshold, factor)\n",
    "    faces = process_mtcnn_bbox(faces, resized_image.shape)\n",
    "    # Map the boxes from the resized frame back to the original frame.\n",
    "    return calibrate_coord(faces, video_scaling_factor)\n",
    "\n",
    "def get_smoothed_coord(x0, x1, y0, y1, shape, ratio=0.65):\n",
    "    '''\n",
    "    Smooth a bounding box against the previous frame's box.\n",
    "\n",
    "    With `use_kalman_filter` off, each coordinate is the weighted average\n",
    "    `ratio * previous + (1 - ratio) * current`, truncated to int. With it on, the\n",
    "    (x0, y0) and (x1, y1) corners are fed to the module-level filters `kf0` and `kf1`;\n",
    "    the predictions are clipped to [0, shape[0]] for x and [0, shape[1]] for y, and\n",
    "    the previous box is returned instead if the prediction collapses to zero width or height.\n",
    "\n",
    "    Returns (x0, x1, y0, y1).\n",
    "    '''\n",
    "    global prev_x0, prev_x1, prev_y0, prev_y1\n",
    "    if not use_kalman_filter:\n",
    "        x0 = int(ratio * prev_x0 + (1-ratio) * x0)\n",
    "        x1 = int(ratio * prev_x1 + (1-ratio) * x1)\n",
    "        y1 = int(ratio * prev_y1 + (1-ratio) * y1)\n",
    "        y0 = int(ratio * prev_y0 + (1-ratio) * y0)\n",
    "    else:\n",
    "        x0y0 = np.array([x0, y0]).astype(np.float32)\n",
    "        x1y1 = np.array([x1, y1]).astype(np.float32)\n",
    "        kf0.correct(x0y0)\n",
    "        pred_x0y0 = kf0.predict()\n",
    "        kf1.correct(x1y1)\n",
    "        pred_x1y1 = kf1.predict()\n",
    "        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.\n",
    "        x0 = np.max([0, pred_x0y0[0][0]]).astype(int)\n",
    "        x1 = np.min([shape[0], pred_x1y1[0][0]]).astype(int)\n",
    "        y0 = np.max([0, pred_x0y0[1][0]]).astype(int)\n",
    "        y1 = np.min([shape[1], pred_x1y1[1][0]]).astype(int)\n",
    "        if x0 == x1 or y0 == y1:\n",
    "            x0, y0, x1, y1 = prev_x0, prev_y0, prev_x1, prev_y1\n",
    "    return x0, x1, y0, y1\n",
    "    \n",
    "def set_global_coord(x0, x1, y0, y1):\n",
    "    '''Store the given box in the module-level prev_x0 / prev_x1 / prev_y0 / prev_y1.'''\n",
    "    global prev_x0, prev_x1, prev_y0, prev_y1\n",
    "    prev_x0, prev_x1 = x0, x1\n",
    "    prev_y0, prev_y1 = y0, y1\n",
    "    \n",
    "def generate_face(ae_input, path_abgr, roi_size, roi_image):\n",
    "    '''\n",
    "    Run the generator on one face crop and alpha-blend its output with the input crop.\n",
    "\n",
    "    # Arguments\n",
    "        ae_input: generator input image; rescaled here with (x + 1) * 255 / 2,\n",
    "            so presumably in [-1, 1] (TODO confirm against the caller).\n",
    "        path_abgr: function returning alpha in channel 0 and colour in the remaining channels.\n",
    "        roi_size: shape of the region to fill; only entries 0 and 1 are used.\n",
    "        roi_image: original region, used as the colour-correction target.\n",
    "\n",
    "    # Returns\n",
    "        (blended image resized to the ROI, un-blurred alpha map resized to the ROI\n",
    "        with a trailing channel axis).\n",
    "    '''\n",
    "    net_out = np.squeeze(np.array([path_abgr([[ae_input]])]))\n",
    "    alpha = net_out[:,:,0] * 255\n",
    "    bgr = np.clip( (net_out[:,:,1:] + 1) * 255 / 2, 0, 255 )\n",
    "    # Keep the un-blurred alpha for the caller; the blurred copy is used for blending.\n",
    "    alpha_sharp = np.copy(alpha)\n",
    "    alpha = cv2.GaussianBlur(alpha ,(7,7),6)\n",
    "    # Landmark matching is switched off here by the `and False`.\n",
    "    if use_landmark_match and False:\n",
    "        resized_roi = cv2.resize(roi_image, (64,64))\n",
    "        bgr, alpha = landmarks_match_mtcnn(resized_roi, bgr, alpha)\n",
    "    alpha = np.expand_dims(alpha, axis=2)\n",
    "    weight = alpha/255\n",
    "    input_bgr = (ae_input + 1) * 255 / 2\n",
    "    blended = (weight * bgr + (1 - weight) * input_bgr).astype('uint8')\n",
    "    if use_color_correction:\n",
    "        blended = color_hist_match(blended, roi_image)\n",
    "    blended = cv2.cvtColor(blended.astype('uint8'), cv2.COLOR_BGR2RGB)\n",
    "    # cv2.resize takes (width, height).\n",
    "    out_size = (roi_size[1],roi_size[0])\n",
    "    blended = cv2.resize(blended, out_size)\n",
    "    alpha_sharp = np.expand_dims(cv2.resize(alpha_sharp, out_size), axis=2)\n",
    "    return blended, alpha_sharp\n",
    "\n",
    "def get_init_mask_map(image):\n",
    "    '''Return an all-zero array with the same shape and dtype as `image`.'''\n",
    "    empty_mask = np.zeros_like(image)\n",
    "    return empty_mask\n",
    "\n",
    "def get_init_comb_img(input_img):\n",
    "    '''Return a float array twice as wide as `input_img`, holding two side-by-side copies of it.'''\n",
    "    height, width, channels = input_img.shape[0], input_img.shape[1], input_img.shape[2]\n",
    "    comb_img = np.zeros([height, width*2, channels])\n",
    "    for panel in range(2):\n",
    "        comb_img[:, panel*width:(panel+1)*width, :] = input_img\n",
    "    return comb_img\n",
    "\n",
    "def get_init_triple_img(input_img, no_face=False):\n",
    "    '''\n",
    "    Return a float canvas three times as wide as `input_img`.\n",
    "\n",
    "    When `no_face` is True, the first two panels are copies of `input_img` and the third\n",
    "    is `input_img` darkened to 15% (truncated to uint8). Otherwise the canvas is all zeros.\n",
    "    '''\n",
    "    height, width, channels = input_img.shape[0], input_img.shape[1], input_img.shape[2]\n",
    "    triple_img = np.zeros([height, width*3, channels])\n",
    "    if no_face:\n",
    "        triple_img[:, :width, :] = input_img\n",
    "        triple_img[:, width:width*2, :] = input_img\n",
    "        triple_img[:, width*2:, :] = (input_img * .15).astype('uint8')\n",
    "    return triple_img\n",
    "\n",
    "def get_mask(roi_image, h, w):\n",
    "    '''\n",
    "    Return a Gaussian-blurred mask shaped like `roi_image`: 255 inside, 0 on the border.\n",
    "\n",
    "    The border is about 1/15 of `h` and `w` on each side.\n",
    "    '''\n",
    "    top, left = h//15, w//15\n",
    "    # `-h//15` parses as `(-h)//15`, so these floor towards minus infinity and can\n",
    "    # differ from `-(h//15)` by one.\n",
    "    bottom, right = -h//15, -w//15\n",
    "    mask = np.zeros_like(roi_image)\n",
    "    mask[top:bottom, left:right, :] = 255\n",
    "    return cv2.GaussianBlur(mask,(15,15),10)\n",
    "\n",
    "def hist_match(source, template):\n",
    "    '''\n",
    "    Remap the values of `source` so that their cumulative distribution matches `template`'s.\n",
    "\n",
    "    Code borrow from:\n",
    "    https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x\n",
    "\n",
    "    Returns an array with the shape of `source` holding the interpolated template values.\n",
    "    '''\n",
    "    out_shape = source.shape\n",
    "    flat_source = source.ravel()\n",
    "    flat_template = template.ravel()\n",
    "\n",
    "    # Unique values, the index of each pixel's value among them, and the counts.\n",
    "    _, inverse_idx, src_counts = np.unique(flat_source, return_inverse=True,\n",
    "                                           return_counts=True)\n",
    "    tmpl_values, tmpl_counts = np.unique(flat_template, return_counts=True)\n",
    "\n",
    "    # Empirical CDFs, normalized so the last entry is 1.\n",
    "    src_cdf = np.cumsum(src_counts).astype(np.float64)\n",
    "    src_cdf /= src_cdf[-1]\n",
    "    tmpl_cdf = np.cumsum(tmpl_counts).astype(np.float64)\n",
    "    tmpl_cdf /= tmpl_cdf[-1]\n",
    "\n",
    "    # For each source quantile, look up the template value at the same quantile.\n",
    "    mapped_values = np.interp(src_cdf, tmpl_cdf, tmpl_values)\n",
    "    return mapped_values[inverse_idx].reshape(out_shape)\n",
    "\n",
    "def color_hist_match(src_im, tar_im):\n",
    "    '''\n",
    "    Histogram-match each of the first three channels of `src_im` to the same channel of `tar_im`.\n",
    "\n",
    "    Returns a float64 array with the matched channels stacked along axis 2.\n",
    "    '''\n",
    "    matched_channels = [hist_match(src_im[:,:,ch], tar_im[:,:,ch]) for ch in range(3)]\n",
    "    return np.stack(matched_channels, axis=2).astype(np.float64)\n",
    "\n",
    "def landmarks_match_mtcnn(source, target, alpha):\n",
    "    '''\n",
    "    Warp `source` and the alpha mask so that source landmarks line up with target landmarks.\n",
    "\n",
    "    MTCNN landmarks are detected on both images, blended with the stored previous\n",
    "    landmarks (weight `ratio`), and a similarity transform from source to target\n",
    "    landmarks is estimated with `umeyama`.\n",
    "\n",
    "    Returns (warped source, warped alpha as float32). If any step fails, for example\n",
    "    because no landmarks were found, returns `(source, alpha)` unchanged.\n",
    "\n",
    "    TODO: Reuse the landmarks of source image. Conceivable bug: coordinate mismatch.\n",
    "    NOTE(review): prev_pnts1/prev_pnts2 are only set on the first successful detection\n",
    "    and never refreshed afterwards - confirm whether that is intended.\n",
    "    '''\n",
    "    global prev_pnts1, prev_pnts2\n",
    "    ratio = 0.2\n",
    "    minsize = 20 # minimum size of face\n",
    "    threshold = [ 0.6, 0.7, 0.93 ]  # three steps's threshold\n",
    "    factor = 0.709 # scale factor\n",
    "    _, pnts1 = mtcnn_detect_face.detect_face(source, minsize, pnet, rnet, onet, threshold, factor) # redundant detection\n",
    "    _, pnts2 = mtcnn_detect_face.detect_face(target, minsize, pnet, rnet, onet, threshold, factor)\n",
    "\n",
    "    # Seed the previous landmarks the first time both detections return 10 rows.\n",
    "    if len(prev_pnts1) == 0 and len(prev_pnts2) == 0:\n",
    "        if pnts1.shape[0] == 10 and pnts2.shape[0] == 10:\n",
    "            prev_pnts1, prev_pnts2 = pnts1, pnts2\n",
    "    try:\n",
    "        landmarks_XY1 = []\n",
    "        landmarks_XY2 = []\n",
    "        # Each pair is built from rows i+5 and i of the landmark array.\n",
    "        for i in range(5):\n",
    "            landmarks_XY1.extend([((1-ratio)*pnts1[i+5][0] + ratio*prev_pnts1[i+5][0],\n",
    "                                   (1-ratio)*pnts1[i][0] + ratio*prev_pnts1[i][0])])\n",
    "            landmarks_XY2.extend([((1-ratio)*pnts2[i+5][0] + ratio*prev_pnts2[i+5][0],\n",
    "                                   (1-ratio)*pnts2[i][0] + ratio*prev_pnts2[i][0])])\n",
    "        M = umeyama(np.array(landmarks_XY1), np.array(landmarks_XY2), True)[0:2]\n",
    "        result = cv2.warpAffine(source, M, (64, 64), borderMode=cv2.BORDER_REPLICATE)\n",
    "        mask = np.stack([alpha, alpha, alpha], axis=2)\n",
    "        assert len(mask.shape) == 3, \"len(mask.shape) is \" + str(len(mask.shape))\n",
    "        mask = cv2.warpAffine(mask, M, (64, 64), borderMode=cv2.BORDER_REPLICATE)\n",
    "        return result, mask[:,:,0].astype(np.float32)\n",
    "    except Exception:\n",
    "        # Best effort: fall back to the unwarped inputs. Catching `Exception` rather than\n",
    "        # using a bare except lets KeyboardInterrupt / SystemExit propagate.\n",
    "        return source, alpha\n",
    "\n",
    "def process_video(input_img): \n",
    "    global prev_x0, prev_x1, prev_y0, prev_y1\n",
    "    global frames      \n",
    "    global pnet, rnet, onet\n",
    "    \"\"\"\n",
    "    The following if statement is meant to solve a bug that has an unknow cause.\n",
    "    \"\"\"\n",
    "    if frames <= 2:\n",
    "        with tf.variable_scope('pnet2', reuse=True):\n",
    "            pnet2 = None\n",
    "            data = tf.placeholder(tf.float32, (None,None,None,3), 'input')\n",
    "            pnet2 = mtcnn_detect_face.PNet({'data':data})\n",
    "            pnet2.load(os.path.join(\"./mtcnn_weights/\", 'det1.npy'), sess)\n",
    "        with tf.variable_scope('rnet2', reuse=True):\n",
    "            rnet2 = None\n",
    "            data = tf.placeholder(tf.float32, (None,24,24,3), 'input')\n",
    "            rnet2 = mtcnn_detect_face.RNet({'data':data})\n",
    "            rnet2.load(os.path.join(\"./mtcnn_weights/\", 'det2.npy'), sess)\n",
    "        with tf.variable_scope('onet2', reuse=True):\n",
    "            onet2 = None\n",
    "            data = tf.placeholder(tf.float32, (None,48,48,3), 'input')\n",
    "            onet2 = mtcnn_detect_face.ONet({'data':data})\n",
    "            onet2.load(os.path.join(\"./mtcnn_weights/\", 'det3.npy'), sess)\n",
    "        pnet = K.function([pnet2.layers['data']],\n",
    "                          [pnet2.layers['conv4-2'], \n",
    "                           pnet2.layers['prob1']])\n",
    "        rnet = K.function([rnet2.layers['data']],\n",
    "                          [rnet2.layers['conv5-2'], \n",
    "                           rnet2.layers['prob1']])\n",
    "        onet = K.function([onet2.layers['data']], \n",
    "                          [onet2.layers['conv6-2'], \n",
    "                           onet2.layers['conv6-3'], \n",
    "                           onet2.layers['prob1']])\n",
    "    \n",
    "    image = input_img\n",
    "    faces = get_faces_bbox(image) # faces: face bbox coord\n",
    "    \n",
    "    if len(faces) == 0:\n",
    "        comb_img = get_init_comb_img(input_img)\n",
    "        triple_img = get_init_triple_img(input_img, no_face=True)\n",
    "    else:\n",
    "        faces = remove_overlaps(faces) # Has non-max suppress already been implemented in MTCNN?\n",
    "        \n",
    "    mask_map = get_init_mask_map(image)\n",
    "    comb_img = get_init_comb_img(input_img)\n",
    "    best_conf_score = 0\n",
    "    \n",
    "    for (x0, y1, x1, y0, conf_score) in faces:    \n",
    "        #print (x0, y1, x1, y0, conf_score)\n",
    "        # smoothing bounding box\n",
    "        if use_smoothed_bbox:\n",
    "            if frames != 0 and conf_score >= best_conf_score:\n",
    "                x0, x1, y0, y1 = get_smoothed_coord(x0, x1, y0, y1, \n",
    "                                                    image.shape, \n",
    "                                                    ratio=0.65 if use_kalman_filter else bbox_moving_avg_coef)\n",
    "                set_global_coord(x0, x1, y0, y1)\n",
    "                best_conf_score = conf_score\n",
    "                frames += 1\n",
    "            elif conf_score <= best_conf_score:\n",
    "                frames += 1\n",
    "            else:\n",
    "                if conf_score >= best_conf_score:\n",
    "                    set_global_coord(x0, x1, y0, y1)\n",
    "                    best_conf_score = conf_score\n",
    "                if use_kalman_filter:\n",
    "                    for i in range(200):\n",
    "                        kf0.predict()\n",
    "                        kf1.predict()\n",
    "                frames += 1\n",
    "                \n",
    "        h = int(x1 - x0)\n",
    "        w = int(y1 - y0)\n",
    "        roi_coef = 15\n",
    "        roi_x0, roi_x1, roi_y0, roi_y1 = int(x0+h//roi_coef), int(x1-h//roi_coef), int(y0+w//roi_coef), int(y1-w//roi_coef)\n",
    "            \n",
    "        cv2_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
    "        roi_image = cv2_img[roi_x0:roi_x1,roi_y0:roi_y1,:]\n",
    "        roi_size = roi_image.shape  \n",
    "        \n",
    "        ae_input = cv2.resize(roi_image, (64,64))/255. * 2 - 1  \n",
    "        np.squeeze(np.array([path_abgr_A([[ae_input]])]))\n",
    "        result, result_a = generate_face(ae_input, path_func, roi_size, roi_image)\n",
    "        if conf_score >= best_conf_score:\n",
    "            mask_map[roi_x0:roi_x1,roi_y0:roi_y1,:] = result_a\n",
    "            mask_map = np.clip(mask_map + .15 * input_img, 0, 255)     \n",
    "        else:\n",
    "            mask_map[roi_x0:roi_x1,roi_y0:roi_y1,:] += result_a\n",
    "            mask_map = np.clip(mask_map, 0, 255)\n",
    "        \n",
    "        if use_smoothed_mask:\n",
    "            mask = get_mask(roi_image, h, w)\n",
    "            roi_rgb = cv2.cvtColor(roi_image, cv2.COLOR_BGR2RGB)\n",
    "            smoothed_result = mask/255 * result + (1-mask/255) * roi_rgb\n",
    "            comb_img[roi_x0:roi_x1, input_img.shape[1]+roi_y0:input_img.shape[1]+roi_y1,:] = smoothed_result\n",
    "        else:\n",
    "            comb_img[roi_x0:roi_x1, input_img.shape[1]+roi_y0:input_img.shape[1]+roi_y1,:] = result\n",
    "            \n",
    "        triple_img = get_init_triple_img(input_img)\n",
    "        triple_img[:, :input_img.shape[1]*2, :] = comb_img\n",
    "        triple_img[:, input_img.shape[1]*2:, :] = mask_map\n",
    "        \n",
    "    global output_type\n",
    "    if output_type == 1:\n",
    "        return comb_img[:, input_img.shape[1]:, :]  # return only result image\n",
    "    elif output_type == 2:\n",
    "        return comb_img  # return input and result image combined as one\n",
    "    elif output_type == 3:\n",
    "        return triple_img #return input,result and mask heatmap image combined as one"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Video making config\n",
    "\n",
    "**Description**\n",
    "```python\n",
    "    video_scaling_offset = 0 # Increase by 1 if OOM happens.\n",
    "    manually_downscale = False # Set True if increasing offset doesn't help\n",
    "    manually_downscale_factor = int(2) # Increase by 1 if OOM still happens.\n",
    "    use_color_correction = False # Option for color corretion\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_kalman_filter = False\n",
    "if use_kalman_filter:\n",
    "    noise_coef = 5e-3 # Increase by 10x if tracking is slow. \n",
    "    kf0 = kalmanfilter_init(noise_coef)\n",
    "    kf1 = kalmanfilter_init(noise_coef)\n",
    "else:\n",
    "    bbox_moving_avg_coef = 0.65\n",
    "    \n",
    "video_scaling_offset = 0 \n",
    "manually_downscale = False\n",
    "manually_downscale_factor = int(2) # should be an positive integer\n",
    "use_color_correction = False\n",
    "use_landmark_match = False # Under developement, This is not functioning.\n",
    "\n",
    "# ========== Change the following line for different output type==========\n",
    "# Output type: \n",
    "#    1. [ result ] \n",
    "#    2. [ source | result ] \n",
    "#    3. [ source | result | mask ]\n",
    "global output_type\n",
    "output_type = 3\n",
    "\n",
    "# Detection threshold:  a float point between 0 and 1. Decrease this value if faces are missed.\n",
    "global detec_threshold\n",
    "detec_threshold = 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate video clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MoviePy] >>>> Building video OUTPUT_VIDEO.mp4\n",
      "[MoviePy] Writing video OUTPUT_VIDEO.mp4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 1/91 [00:06<10:09,  6.78s/it]"
     ]
    }
   ],
   "source": [
    "# Variables for smoothing bounding box\n",
    "global prev_x0, prev_x1, prev_y0, prev_y1\n",
    "global frames\n",
    "global prev_pnts1, prev_pnts2\n",
    "prev_x0 = prev_x1 = prev_y0 = prev_y1 = 0\n",
    "frames = 0\n",
    "prev_pnts1 = prev_pnts2 = np.array([])\n",
    "\n",
    "output = 'OUTPUT_VIDEO.mp4'\n",
    "clip1 = VideoFileClip(\"INPUT_VIDEO.mp4\")\n",
    "clip = clip1.fl_image(process_video)#.subclip(7.5, 9) #NOTE: this function expects color images!!\n",
    "%time clip.write_videofile(output, audio=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gc\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
